{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from sklearn.datasets import load_digits\n",
    "from sklearn import datasets\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from pytorch_model_summary import summary"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**DISCLAIMER**\n",
    "\n",
    "The presented code is not optimized; it serves an educational purpose. It is written for CPU, and it uses only a small network of causal 1D convolutions and an extremely simplistic dataset. However, it contains all the components needed to understand how an autoregressive model (ARM) works, and it should be rather easy to extend it to more sophisticated models. This code can be run on almost any laptop/PC, and it takes a couple of minutes at most to get the result."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this example, we go wild and use a dataset that is simpler than MNIST! We use a scikit-learn dataset called Digits. It consists of ~1800 images of size 8x8, and each pixel can take values in $\\{0, 1, \\ldots, 16\\}$.\n",
    "\n",
    "The goal of using this dataset is that everyone can run it on a laptop, without a GPU or any other special hardware."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Digits(Dataset):\n",
    "    \"\"\"Scikit-Learn Digits dataset, split by row position into train/val/test parts.\n",
    "\n",
    "    Each item is one row of ``load_digits().data`` cast to ``float32``.\n",
    "    ``mode='train'`` selects rows ``[:1000]``, ``mode='val'`` rows ``[1000:1350]``,\n",
    "    and any other value selects the remaining rows ``[1350:]``.\n",
    "    If ``transforms`` is given, it is applied to every sample on access.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, mode='train', transforms=None):\n",
    "        if mode == 'train':\n",
    "            rows = slice(None, 1000)\n",
    "        elif mode == 'val':\n",
    "            rows = slice(1000, 1350)\n",
    "        else:\n",
    "            # Every other mode string falls through to the held-out test rows.\n",
    "            rows = slice(1350, None)\n",
    "\n",
    "        self.data = load_digits().data[rows].astype(np.float32)\n",
    "        self.transforms = transforms\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.data.shape[0]\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        item = self.data[idx]\n",
    "        return self.transforms(item) if self.transforms else item"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ARM code"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Please see the blogpost for details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CausalConv1d(nn.Module):\n",
    "    \"\"\"\n",
    "    A causal 1D convolution.\n",
    "\n",
    "    The input is zero-padded on the left only, so the output at position t never\n",
    "    depends on inputs at positions later than t, and the output is as long as the input.\n",
    "    With A=True one extra step of left padding is added and the last output step is\n",
    "    dropped, so the output at position t also excludes the input at position t itself.\n",
    "    Extra keyword arguments are forwarded to torch.nn.Conv1d.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_channels, out_channels, kernel_size, dilation, A=False, **kwargs):\n",
    "        super().__init__()\n",
    "\n",
    "        # attributes:\n",
    "        self.kernel_size = kernel_size\n",
    "        self.dilation = dilation\n",
    "        self.A = A\n",
    "\n",
    "        # Left padding covering the receptive field; option A adds one more step.\n",
    "        self.padding = dilation * (kernel_size - 1) + 1 * A\n",
    "\n",
    "        # module: padding is applied manually in forward(), hence padding=0 here.\n",
    "        self.conv1d = nn.Conv1d(in_channels, out_channels, kernel_size,\n",
    "                                stride=1, padding=0, dilation=dilation, **kwargs)\n",
    "\n",
    "    def forward(self, x):\n",
    "        padded = torch.nn.functional.pad(x, (self.padding, 0))\n",
    "        out = self.conv1d(padded)\n",
    "        # Option A yields one step too many; drop the last one to match the input length.\n",
    "        return out[:, :, :-1] if self.A else out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Probabilities are clamped to [EPS, 1 - EPS] before taking the logarithm.\n",
    "EPS = 1.e-5\n",
    "\n",
    "def log_categorical(x, p, num_classes=256, reduction=None, dim=None):\n",
    "    \"\"\"Log-probability of the values `x` under the categorical distributions `p`.\n",
    "\n",
    "    Args:\n",
    "        x: tensor of class indices; it is cast to long and one-hot encoded.\n",
    "        p: tensor of class probabilities with the classes on the last axis, so that\n",
    "            it can be multiplied with the one-hot encoding of `x`.\n",
    "        num_classes: number of categories used for the one-hot encoding.\n",
    "        reduction: 'avg' to average or 'sum' to sum over `dim`; any other value\n",
    "            returns the unreduced tensor.\n",
    "        dim: dimension(s) to reduce over; None reduces over all elements.\n",
    "\n",
    "    Returns:\n",
    "        one_hot(x) * log(p), reduced as requested.\n",
    "    \"\"\"\n",
    "    x_one_hot = F.one_hot(x.long(), num_classes=num_classes)\n",
    "    # Clamping avoids log(0) = -inf for classes with zero probability.\n",
    "    log_p = x_one_hot * torch.log(torch.clamp(p, EPS, 1. - EPS))\n",
    "    if reduction == 'avg':\n",
    "        # Passing dim=None positionally is rejected by some PyTorch versions, so reduce explicitly.\n",
    "        return log_p.mean() if dim is None else torch.mean(log_p, dim)\n",
    "    elif reduction == 'sum':\n",
    "        return log_p.sum() if dim is None else torch.sum(log_p, dim)\n",
    "    else:\n",
    "        return log_p"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ARM(nn.Module):\n",
    "    \"\"\"Autoregressive model over D discrete variables with num_vals categories each.\n",
    "\n",
    "    Args:\n",
    "        net: network applied to the input reshaped to (batch, 1, length); its output\n",
    "            is read as per-position class scores of shape (batch, num_vals, length).\n",
    "        D: number of variables generated by `sample`.\n",
    "        num_vals: number of categories per variable.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, net, D=2, num_vals=256):\n",
    "        super(ARM, self).__init__()\n",
    "\n",
    "        print('ARM by JT.')\n",
    "\n",
    "        self.net = net\n",
    "        self.num_vals = num_vals\n",
    "        self.D = D\n",
    "\n",
    "    def f(self, x):\n",
    "        \"\"\"Return class probabilities for every position, with the classes on the last axis.\"\"\"\n",
    "        h = self.net(x.unsqueeze(1))\n",
    "\n",
    "        # Move the class axis last and normalize over it.\n",
    "        h = h.permute(0, 2, 1)\n",
    "        p = torch.softmax(h, 2)\n",
    "        return p\n",
    "\n",
    "    def forward(self, x, reduction='avg'):\n",
    "        \"\"\"Return the negative log-likelihood of `x`, averaged or summed over the batch.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if `reduction` is neither 'avg' nor 'sum'.\n",
    "        \"\"\"\n",
    "        if reduction == 'avg':\n",
    "            return -(self.log_prob(x).mean())\n",
    "        elif reduction == 'sum':\n",
    "            return -(self.log_prob(x).sum())\n",
    "        else:\n",
    "            raise ValueError('reduction could be either `avg` or `sum`.')\n",
    "\n",
    "    def log_prob(self, x):\n",
    "        \"\"\"Return the log-probability of every example in `x` (one value per row).\"\"\"\n",
    "        mu_d = self.f(x)\n",
    "        # The sum over classes picks the log-probability of the observed value;\n",
    "        # the sum over positions then gives the joint log-probability.\n",
    "        log_p = log_categorical(x, mu_d, num_classes=self.num_vals, reduction='sum', dim=-1).sum(-1)\n",
    "\n",
    "        return log_p\n",
    "\n",
    "    def sample(self, batch_size):\n",
    "        \"\"\"Draw `batch_size` samples of length D, filling in one position at a time.\"\"\"\n",
    "        x_new = torch.zeros((batch_size, self.D))\n",
    "\n",
    "        # Sampling never needs gradients; without no_grad every one of the D\n",
    "        # forward passes would build an autograd graph that is thrown away.\n",
    "        with torch.no_grad():\n",
    "            for d in range(self.D):\n",
    "                p = self.f(x_new)\n",
    "                x_new_d = torch.multinomial(p[:, d, :], num_samples=1)\n",
    "                x_new[:, d] = x_new_d[:, 0]\n",
    "\n",
    "        return x_new"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Auxiliary functions: training, evaluation, plotting"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It's rather self-explanatory, isn't it?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluation(test_loader, name=None, model_best=None, epoch=None):\n",
    "    \"\"\"Compute the average negative log-likelihood per example over a data loader.\n",
    "\n",
    "    Args:\n",
    "        test_loader: iterable of batches whose first dimension is the batch size.\n",
    "        name: path prefix of a saved model; `name + '.model'` is loaded when\n",
    "            `model_best` is None.\n",
    "        model_best: model to evaluate; it must provide `forward(batch, reduction='sum')`.\n",
    "        epoch: if given, the result is printed as the validation loss of that epoch,\n",
    "            otherwise as the final loss.\n",
    "\n",
    "    Returns:\n",
    "        The summed loss divided by the number of examples.\n",
    "    \"\"\"\n",
    "    # EVALUATION\n",
    "    if model_best is None:\n",
    "        # load best performing model\n",
    "        # NOTE(review): this unpickles a whole model object, so only load files you created yourself.\n",
    "        model_best = torch.load(name + '.model')\n",
    "\n",
    "    model_best.eval()\n",
    "    loss = 0.\n",
    "    N = 0.\n",
    "    # Evaluation needs no gradients; disabling autograd saves memory and time.\n",
    "    with torch.no_grad():\n",
    "        for test_batch in test_loader:\n",
    "            loss_t = model_best.forward(test_batch, reduction='sum')\n",
    "            loss = loss + loss_t.item()\n",
    "            N = N + test_batch.shape[0]\n",
    "    loss = loss / N\n",
    "\n",
    "    if epoch is None:\n",
    "        print(f'FINAL LOSS: nll={loss}')\n",
    "    else:\n",
    "        print(f'Epoch: {epoch}, val nll={loss}')\n",
    "\n",
    "    return loss\n",
    "\n",
    "\n",
    "def samples_real(name, test_loader):\n",
    "    \"\"\"Save a 4x4 grid of the first 16 images of the first batch to `name + '_real_images.pdf'`.\"\"\"\n",
    "    # REAL-------\n",
    "    num_x = 4\n",
    "    num_y = 4\n",
    "    x = next(iter(test_loader)).detach().numpy()\n",
    "\n",
    "    fig, axes = plt.subplots(num_x, num_y)\n",
    "    for i, ax in enumerate(axes.flatten()):\n",
    "        plottable_image = np.reshape(x[i], (8, 8))\n",
    "        ax.imshow(plottable_image, cmap='gray')\n",
    "        ax.axis('off')\n",
    "\n",
    "    fig.savefig(name + '_real_images.pdf', bbox_inches='tight')\n",
    "    plt.close(fig)\n",
    "\n",
    "\n",
    "def samples_generated(name, data_loader, extra_name=''):\n",
    "    \"\"\"Sample 16 images from the model saved at `name + '.model'` and save them as a 4x4 grid.\n",
    "\n",
    "    `data_loader` is kept in the signature for backward compatibility but is not used:\n",
    "    the images come from the model only.\n",
    "    \"\"\"\n",
    "    # GENERATIONS-------\n",
    "    # NOTE(review): this unpickles a whole model object, so only load files you created yourself.\n",
    "    model_best = torch.load(name + '.model')\n",
    "    model_best.eval()\n",
    "\n",
    "    num_x = 4\n",
    "    num_y = 4\n",
    "    # Sampling needs no gradients.\n",
    "    with torch.no_grad():\n",
    "        x = model_best.sample(num_x * num_y)\n",
    "    x = x.detach().numpy()\n",
    "\n",
    "    fig, axes = plt.subplots(num_x, num_y)\n",
    "    for i, ax in enumerate(axes.flatten()):\n",
    "        plottable_image = np.reshape(x[i], (8, 8))\n",
    "        ax.imshow(plottable_image, cmap='gray')\n",
    "        ax.axis('off')\n",
    "\n",
    "    fig.savefig(name + '_generated_images' + extra_name + '.pdf', bbox_inches='tight')\n",
    "    plt.close(fig)\n",
    "\n",
    "\n",
    "def plot_curve(name, nll_val):\n",
    "    \"\"\"Plot the validation loss per epoch and save it to `name + '_nll_val_curve.pdf'`.\"\"\"\n",
    "    plt.plot(np.arange(len(nll_val)), nll_val, linewidth='3')\n",
    "    plt.xlabel('epochs')\n",
    "    plt.ylabel('nll')\n",
    "    plt.savefig(name + '_nll_val_curve.pdf', bbox_inches='tight')\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def training(name, max_patience, num_epochs, model, optimizer, training_loader, val_loader):\n",
    "    \"\"\"Train `model` with early stopping on the validation loss.\n",
    "\n",
    "    After every epoch the model is evaluated on `val_loader`. Whenever the validation\n",
    "    loss improves (and always after the first epoch) the model is saved to\n",
    "    `name + '.model'`; on improvements after the first epoch, generated samples are\n",
    "    saved as well. Training stops once the loss has failed to improve for more than\n",
    "    `max_patience` consecutive epochs, or after `num_epochs` epochs.\n",
    "\n",
    "    Returns:\n",
    "        A numpy array with the validation loss of every completed epoch.\n",
    "    \"\"\"\n",
    "    nll_val = []\n",
    "    best_nll = float('inf')\n",
    "    patience = 0\n",
    "\n",
    "    # Main loop\n",
    "    for e in range(num_epochs):\n",
    "        # TRAINING\n",
    "        model.train()\n",
    "        for batch in training_loader:\n",
    "            # Models that set a truthy `dequantization` attribute get uniform noise added to the data.\n",
    "            if getattr(model, 'dequantization', False):\n",
    "                batch = batch + torch.rand(batch.shape)\n",
    "            loss = model.forward(batch)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            # Every step builds a fresh graph, so nothing has to be retained after backward().\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "        # Validation\n",
    "        loss_val = evaluation(val_loader, model_best=model, epoch=e)\n",
    "        nll_val.append(loss_val)  # save for plotting\n",
    "\n",
    "        if e == 0 or loss_val < best_nll:\n",
    "            print('saved!')\n",
    "            torch.save(model, name + '.model')\n",
    "            best_nll = loss_val\n",
    "            patience = 0\n",
    "\n",
    "            # The first epoch only sets the baseline; samples are drawn on later improvements.\n",
    "            if e > 0:\n",
    "                samples_generated(name, val_loader, extra_name=\"_epoch_\" + str(e))\n",
    "        else:\n",
    "            patience = patience + 1\n",
    "\n",
    "        if patience > max_patience:\n",
    "            break\n",
    "\n",
    "    nll_val = np.asarray(nll_val)\n",
    "\n",
    "    return nll_val"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initialize dataloaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = Digits(mode='train')\n",
    "val_data = Digits(mode='val')\n",
    "test_data = Digits(mode='test')\n",
    "\n",
    "# Only the training data is shuffled; validation and test batches keep a fixed order.\n",
    "training_loader = DataLoader(train_data, batch_size=64, shuffle=True)\n",
    "val_loader = DataLoader(val_data, batch_size=64, shuffle=False)\n",
    "test_loader = DataLoader(test_data, batch_size=64, shuffle=False)\n",
    "\n",
    "# All models, plots and losses are written below this directory.\n",
    "result_dir = 'results/'\n",
    "# exist_ok makes re-running the cell safe without a separate existence check.\n",
    "os.makedirs(result_dir, exist_ok=True)\n",
    "name = 'arm'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Hyperparams"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "D = 64   # input dimension (8x8 images, flattened)\n",
    "M = 256  # the number of channels in the hidden causal-convolution layers\n",
    "\n",
    "lr = 1e-3 # learning rate\n",
    "num_epochs = 1000 # max. number of epochs\n",
    "max_patience = 20 # early stopping: training stops once the validation loss has failed to improve for more than 20 epochs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initialize ARM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "likelihood_type = 'categorical'\n",
    "\n",
    "# Number of categories per pixel.\n",
    "num_vals = 17\n",
    "\n",
    "# Kernel size shared by all causal convolutions.\n",
    "kernel = 7\n",
    "\n",
    "# Only the first layer uses option A, so the prediction for a position never sees the\n",
    "# value at that position; the later layers may look at their own position because\n",
    "# their inputs already exclude it.\n",
    "net = nn.Sequential(\n",
    "    CausalConv1d(in_channels=1, out_channels=M, dilation=1, kernel_size=kernel, A=True, bias=True),\n",
    "    nn.LeakyReLU(),\n",
    "    CausalConv1d(in_channels=M, out_channels=M, dilation=1, kernel_size=kernel, A=False, bias=True),\n",
    "    nn.LeakyReLU(),\n",
    "    CausalConv1d(in_channels=M, out_channels=M, dilation=1, kernel_size=kernel, A=False, bias=True),\n",
    "    nn.LeakyReLU(),\n",
    "    CausalConv1d(in_channels=M, out_channels=num_vals, dilation=1, kernel_size=kernel, A=False, bias=True))\n",
    "\n",
    "model = ARM(net, D=D, num_vals=num_vals)\n",
    "\n",
    "# Print the summary (like in Keras); the dummy input uses D so it follows the configured input size.\n",
    "print(summary(model, torch.zeros(1, D), show_input=False, show_hierarchical=False))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Let's play! Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# OPTIMIZER: Adamax over the trainable parameters only.\n",
    "optimizer = torch.optim.Adamax(\n",
    "    [param for param in model.parameters() if param.requires_grad],\n",
    "    lr=lr,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training procedure: returns the validation loss of every completed epoch.\n",
    "nll_val = training(\n",
    "    name=result_dir + name,\n",
    "    max_patience=max_patience,\n",
    "    num_epochs=num_epochs,\n",
    "    model=model,\n",
    "    optimizer=optimizer,\n",
    "    training_loader=training_loader,\n",
    "    val_loader=val_loader,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the best saved model on the test set and store the result as text.\n",
    "test_loss = evaluation(name=result_dir + name, test_loader=test_loader)\n",
    "# The context manager closes the file even if writing fails.\n",
    "with open(result_dir + name + '_test_loss.txt', \"w\") as f:\n",
    "    f.write(str(test_loss))\n",
    "\n",
    "samples_real(result_dir + name, test_loader)\n",
    "\n",
    "plot_curve(result_dir + name, nll_val)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
